import torch
def loss(image, label, mean, variance, beta=1):
    """Return a reconstruction + weighted KL-divergence loss (beta-VAE style).

    The loss is ``mse(image, label) + beta * KL``. KL is the closed-form
    divergence KL(N(mean, variance) || N(0, 1)) for a diagonal Gaussian,
    summed over every element of ``mean``/``variance``.

    Args:
        image: Reconstructed tensor.
        label: Target tensor, with the same shape as ``image``.
        mean: Tensor of latent means.
        variance: Tensor of latent variances, with the same shape as ``mean``.
            This is the variance itself, not the log-variance. It must be
            strictly positive, otherwise ``log`` yields -inf/NaN.
        beta: Weight applied to the KL term.

    Returns:
        A scalar tensor holding the total loss.
    """
    # Reconstruction term. The functional form with the default
    # reduction='mean' matches torch.nn.MSELoss() without building a module
    # on every call.
    mse_loss = torch.nn.functional.mse_loss(image, label)

    # KL-divergence term, summed over all elements.
    # NOTE(review): the MSE term is averaged while the KL term is summed, so
    # their relative scale depends on tensor sizes. Confirm this is intended
    # or tune `beta` accordingly.
    kl_divergence = -0.5 * torch.sum(1 + torch.log(variance) - mean.pow(2) - variance)

    # Total loss.
    total_loss = mse_loss + beta * kl_divergence
    return total_loss

